Example: Pruning
This example shows how to use hyperparameter tuning with pruning, which stops unpromising trials early to save computation time.
We import the breast cancer dataset from sklearn.datasets. This is a small and easy-to-train dataset whose goal is to predict whether or not a patient has breast cancer.
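Before running the pipeline, it helps to see the pruning mechanism that ATOM delegates to optuna. The following is a minimal standalone sketch, not part of this example's pipeline: the objective reports intermediate values with trial.report, and the pruner may abort the trial through trial.should_prune. The toy objective below is purely illustrative.

# Minimal, self-contained sketch of optuna's pruning protocol.
# The objective is a toy stand-in, not part of this example's pipeline.
import optuna
from optuna.pruners import HyperbandPruner

def toy_objective(trial):
    lr = trial.suggest_float("lr", 1e-4, 1e-1, log=True)
    score = 0.0
    for step in range(100):
        score += lr / 100  # stand-in for one training iteration
        trial.report(score, step)  # report the intermediate score
        if trial.should_prune():  # pruner decides based on reported history
            raise optuna.TrialPruned()
    return score

study = optuna.create_study(direction="maximize", pruner=HyperbandPruner())
study.optimize(toy_objective, n_trials=10)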
Load the data
In [1]:
# Import packages
from sklearn.datasets import load_breast_cancer
from optuna.pruners import HyperbandPruner
from atom import ATOMClassifier
In [2]:
# Load the data
X, y = load_breast_cancer(return_X_y=True)
Run the pipeline
In [3]:
# Initialize atom
atom = ATOMClassifier(X, y, verbose=2, random_state=1)
<< ================== ATOM ================== >>

Configuration ==================== >>
Algorithm task: Binary classification.

Dataset stats ==================== >>
Shape: (569, 31)
Train set size: 456
Test set size: 113
-------------------------------------
Memory: 141.24 kB
Scaled: False
Outlier values: 167 (1.2%)
In [4]:
# Use ht_params to specify a custom pruner
# Note that pruned trials show the number of iterations they completed
atom.run(
models="SGD",
metric="f1",
n_trials=25,
ht_params={
"distributions": ["penalty", "max_iter"],
"pruner": HyperbandPruner(),
}
)
Training ========================= >>
Models: SGD
Metric: f1


Running hyperparameter tuning for StochasticGradientDescent...
| trial | penalty | max_iter |     f1 | best_f1 | time_trial | time_ht | state    |
| ----- | ------- | -------- | ------ | ------- | ---------- | ------- | -------- |
|     0 |      l1 |      650 | 0.9558 |  0.9558 |     2.801s |  2.801s | COMPLETE |
|     1 | elast.. |     1050 | 0.9744 |  0.9744 |     4.590s |  7.390s | COMPLETE |
|     2 | elast.. |      500 | 0.9828 |  0.9828 |     0.033s |  7.423s | PRUNED   |
|     3 |    None |      700 | 0.9739 |  0.9828 |     2.951s | 10.374s | COMPLETE |
|     4 |      l1 |     1400 | 0.9735 |  0.9828 |     0.033s | 10.407s | PRUNED   |
|     5 |    None |     1400 | 0.9735 |  0.9828 |     5.994s | 16.401s | COMPLETE |
|     6 |      l2 |     1200 | 0.9825 |  0.9828 |     5.246s | 21.647s | COMPLETE |
|     7 |      l2 |     1250 | 0.9825 |  0.9828 |     5.436s | 27.083s | COMPLETE |
|     8 |    None |      600 | 0.9828 |  0.9828 |     0.023s | 27.106s | PRUNED   |
|     9 |      l1 |      600 | 0.9402 |  0.9828 |     0.030s | 27.136s | PRUNED   |
|    10 |      l2 |      950 | 0.9565 |  0.9828 |     4.118s | 31.254s | COMPLETE |
|    11 |      l2 |     1200 | 0.9825 |  0.9828 |     0.005s | 31.259s | COMPLETE |
|    12 |      l2 |     1200 | 0.9825 |  0.9828 |     0.005s | 31.264s | COMPLETE |
|    13 |      l2 |     1200 | 0.9825 |  0.9828 |     0.005s | 31.269s | COMPLETE |
|    14 |      l2 |     1500 | 0.9573 |  0.9828 |     0.038s | 31.306s | PRUNED   |
|    15 |      l2 |      950 | 0.9565 |  0.9828 |     0.005s | 31.311s | COMPLETE |
|    16 |      l2 |     1100 | 0.9391 |  0.9828 |     0.040s | 31.351s | PRUNED   |
|    17 |      l2 |      850 | 0.9831 |  0.9831 |     0.030s | 31.381s | PRUNED   |
|    18 | elast.. |     1300 |  0.931 |  0.9831 |     0.029s | 31.410s | PRUNED   |
|    19 |      l2 |     1300 | 0.9649 |  0.9831 |     0.067s | 31.478s | PRUNED   |
|    20 |      l2 |      800 | 0.9661 |  0.9831 |     0.039s | 31.517s | PRUNED   |
|    21 |      l2 |     1150 | 0.9402 |  0.9831 |     0.032s | 31.548s | PRUNED   |
|    22 |      l2 |     1300 | 0.9573 |  0.9831 |     0.038s | 31.586s | PRUNED   |
|    23 |      l2 |     1250 | 0.9825 |  0.9831 |     0.008s | 31.594s | COMPLETE |
|    24 |      l2 |     1050 | 0.9565 |  0.9831 |     0.070s | 31.665s | PRUNED   |
Hyperparameter tuning ---------------------------
Best trial --> 6
Best parameters:
 --> penalty: l2
 --> max_iter: 1200
Best evaluation --> f1: 0.9825
Time elapsed: 31.665s
Fit ---------------------------------------------
Train evaluation --> f1: 0.993
Test evaluation --> f1: 0.9722
Time elapsed: 8.384s
-------------------------------------------------
Time: 40.049s


Final results ==================== >>
Total time: 40.301s
-------------------------------------
StochasticGradientDescent --> f1: 0.9722
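The pruned trials can also be inspected programmatically. A minimal sketch, assuming the fitted model exposes its underlying optuna study through the study attribute (verify this against your ATOM version):

from optuna.trial import TrialState

# Hypothetical accessor; check that your ATOM version exposes the study.
study = atom.sgd.study
pruned = [t for t in study.trials if t.state == TrialState.PRUNED]
print(f"{len(pruned)} of {len(study.trials)} trials were pruned")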
Analyze the results
In [5]:
atom.plot_trials()
In [6]:
atom.plot_hyperparameter_importance()